/* Copyright (c) 2022, Arm Limited and Contributors. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

/* This example connects to an AWS IoT MQTT broker over TLS (Mbed TLS), publishes a demo
 * message to a topic, disconnects, and logs its progress to serial.
 */

// Log module name for this file.
// NOTE(review): presumably read by pw_log/log.h, which is why it is defined before the includes - confirm.
#define PW_LOG_MODULE_NAME "main"

#include "iotsdk/ip_network_api.h"

#include "Driver_USART.h"
#include "agent_message_processor.h"
#include "aws_clientcredential.h"
#include "aws_clientcredential_keys.h"
#include "aws_network_manager/aws_network_manager.h"
#include "cmsis_os2.h"
#include "core_mqtt.h"
#include "core_mqtt_agent.h"
#include "core_mqtt_agent_message_interface.h"
#include "core_mqtt_config.h"
#include "core_mqtt_serializer.h"
#include "mbedtls/platform.h"
#include "message_handler.h"
#include "mqtt_subscription_manager.h"
#include "pw_log/log.h"

#include <inttypes.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

// Topic the demo message is published to (see app_task).
#define DEMO_PUBLISH_TOPIC              "OpenIoTSDK/demo"
// Payload of the demo message published by app_task.
#define DEMO_PAYLOAD                    "Hello World!"
// Number of command pointers the MQTT agent command queue created in app_task can hold.
#define MQTT_AGENT_COMMAND_QUEUE_LENGTH (25)

/**
 * @brief The maximum amount of time in milliseconds to wait for the commands
 * to be posted to the MQTT agent should the MQTT agent's command queue be full.
 * Tasks wait in the Blocked state, so don't use any CPU time.
 */
#define MQTT_AGENT_SEND_BLOCK_TIME_MS (3000U)

/**
 * @brief Per-command context passed to the command completion callback
 * (on_command_complete) through MQTTAgentCommandInfo_t.
 */
struct MQTTAgentCommandContext {
    MQTTStatus_t xReturnStatus; // Result of the command; written by on_command_complete
    osThreadId_t xTaskToNotify; // Thread that issued the command (app_task stores its own id here)
    void *pArgs;                // Extra arguments for the command; app_task sets this to NULL
};

// Byte string handed to the network layer through networkConfig.drbg.
// NOTE(review): presumably used as the DRBG personalisation string - confirm against AwsNetwork_init.
const unsigned char drbg_str[] = "Open IoT SDK - AWS Client example";

// Connection settings passed to AwsNetwork_init() in app_task: broker endpoint and port,
// credentials taken from aws_clientcredential*.h, and timeout/retry parameters.
// NOTE(review): the timeout and backoff values look like milliseconds - confirm against the
// AwsNetworkContextConfig_t documentation.
static const AwsNetworkContextConfig_t networkConfig = {.server_name = clientcredentialMQTT_BROKER_ENDPOINT,
                                                        .server_certificate = keyCA1_ROOT_CERTIFICATE_PEM,
                                                        .client_certificate = keyCLIENT_CERTIFICATE_PEM,
                                                        .client_private_key = keyCLIENT_PRIVATE_KEY_PEM,
                                                        .client_public_key = keyCLIENT_PUBLIC_KEY_PEM,
                                                        .drbg = drbg_str,
                                                        .drbg_len = sizeof(drbg_str),
                                                        .send_timeout = 500,
                                                        .receive_timeout = 500,
                                                        .retry_max_attempts = 5,
                                                        .backoff_max_delay = 5000,
                                                        .backoff_base = 500,
                                                        .port = clientcredentialMQTT_BROKER_PORT};

// Network (TLS) context shared by AwsNetwork_init/connect/close and the MQTT transport below.
static NetworkContext_t network_context = {0};

// Transport interface given to MQTTAgent_Init(): MQTT traffic is sent and received through
// the AwsNetwork functions operating on network_context.
static TransportInterface_t mqtt_transport = {
    .send = AwsNetwork_send, .recv = AwsNetwork_recv, .pNetworkContext = &network_context};

// Fixed buffer handed to MQTTAgent_Init() as the MQTT working buffer.
static uint8_t buffer[1024] = {0};
static MQTTFixedBuffer_t mqttBuffer = {
    .pBuffer = buffer,
    .size = sizeof(buffer),
};

// Parameters of the MQTT CONNECT issued in connect_mqtt(); the client identifier is the
// IoT thing name (length excludes the terminating NUL).
static const MQTTConnectInfo_t mqtt_connect_info = {
    .cleanSession = true,
    .pClientIdentifier = clientcredentialIOT_THING_NAME,
    .clientIdentifierLength = (uint16_t)(sizeof(clientcredentialIOT_THING_NAME) - 1),
    .keepAliveSeconds = 60U,
};

// MQTT agent context used by every MQTTAgent_* call in this file.
static MQTTAgentContext_t mqtt_context = {0};

// Network state events exchanged between network_event_callback and app_task.
// NET_EVENT_NONE is used by app_task when no message arrived within its polling timeout.
typedef enum net_event_t { NET_EVENT_NETWORK_UP, NET_EVENT_NETWORK_DOWN, NET_EVENT_NONE } net_event_t;
// Message carried by net_msg_queue.
typedef struct {
    net_event_t event;
    int32_t return_code; // Not read anywhere in this file
} net_msg_t;

// Queue of net_msg_t created in main(); filled by network_event_callback, drained by app_task.
static osMessageQueueId_t net_msg_queue = NULL;

// Serial driver used for logging; obtained from get_example_serial() in serial_setup().
static ARM_DRIVER_USART *serial = NULL;
extern ARM_DRIVER_USART *get_example_serial();

// Events sent and received by the MQTT receiving thread.
// - EVENT_TERMINATE_MQTT_THREAD:   set by app_task to ask the MQTT thread to stop.
// - EVENT_MQTT_THREAD_TERMINATING: set by the MQTT thread once it has seen the request.
// - EVENT_MQTT_COMMAND_COMPLETE:   set by on_command_complete when a command succeeded.
#define EVENT_TERMINATE_MQTT_THREAD   (1 << 0)
#define EVENT_MQTT_THREAD_TERMINATING (1 << 1)
#define EVENT_MQTT_COMMAND_COMPLETE   (1 << 2)

// Shared event flags for the MQTT reception thread; created in app_task before that
// thread is started.
osEventFlagsId_t mqtt_recv_thread_events = NULL;

// Forward declarations of the file-local functions defined below.
static void serial_setup();
static int mbedtls_platform_example_nv_seed_read(unsigned char *buf, size_t buf_len);
static int mbedtls_platform_example_nv_seed_write(unsigned char *buf, size_t buf_len);
static void network_event_callback(const network_state_event_args_t *event_args);
static void app_task(void *arg);
static uint32_t get_time_ms(void);
static void mqtt_background_process(void *arg);
static int connect_mqtt(NetworkContext_t *ctx);
static void on_incoming_mqtt_publish_message(struct MQTTAgentContext *context,
                                             uint16_t packet_id,
                                             MQTTPublishInfo_t *incoming_packet_info);
static void on_command_complete(MQTTAgentCommandContext_t *context, MQTTAgentReturnInfo_t *return_info);

/**
 * @brief Program entry point: bring up serial, the RTOS kernel, logging, the network task
 * and Mbed TLS, create the demo thread (app_task), then start the kernel.
 *
 * @return EXIT_FAILURE if any initialisation step fails, 0 otherwise.
 */
int main(void)
{
    serial_setup();

    osStatus_t status = osKernelInitialize();
    if (status != osOK) {
        return EXIT_FAILURE;
    }

    // The PW log backend needs a lock, so it can only be initialised once the
    // kernel has been initialised.
    pw_log_cmsis_driver_init(serial);

    // Queue carrying network up/down notifications to app_task
    net_msg_queue = osMessageQueueNew(10, sizeof(net_msg_t), NULL);
    if (net_msg_queue == NULL) {
        PW_LOG_ERROR("Failed to create a net msg queue");
        return EXIT_FAILURE;
    }

    // The demo thread runs with an 8192-byte stack
    const osThreadAttr_t app_thread_attr = {.stack_size = 8192};
    if (osThreadNew(app_task, NULL, &app_thread_attr) == NULL) {
        PW_LOG_ERROR("Failed to create thread");
        return EXIT_FAILURE;
    }

    PW_LOG_INFO("Initialising network");
    status = start_network_task(network_event_callback, 0);
    if (status != osOK) {
        PW_LOG_ERROR("Failed to start lwip");
        return EXIT_FAILURE;
    }

    if (osKernelGetState() != osKernelReady) {
        PW_LOG_ERROR("Kernel not ready");
        return EXIT_FAILURE;
    }

    // Configure Mbed TLS: RTOS-backed threading and the example NV seed callbacks
    mbedtls_threading_set_cmsis_rtos();
    if (mbedtls_platform_set_nv_seed(mbedtls_platform_example_nv_seed_read, mbedtls_platform_example_nv_seed_write)
        != 0) {
        PW_LOG_ERROR("Failed to initialize NV seed functions");
        return EXIT_FAILURE;
    }

    PW_LOG_INFO("Starting kernel");
    status = osKernelStart();
    if (status != osOK) {
        PW_LOG_ERROR("osKernelStart failed: %d", status);
        return EXIT_FAILURE;
    }

    return 0;
}

static void app_task(void *arg)
{
    (void)arg;

    // Wait for the connection to be established
    PW_LOG_INFO("Awaiting network connection");
    while (1) {
        net_msg_t msg;
        if (osMessageQueueGet(net_msg_queue, &msg, NULL, 1000) != osOK) {
            msg.event = NET_EVENT_NONE;
        }

        if (msg.event == NET_EVENT_NETWORK_UP) {
            PW_LOG_INFO("Network connection enabled");
            break;
        } else if (msg.event == NET_EVENT_NETWORK_DOWN) {
            PW_LOG_DEBUG("Network is not enabled");
        }
    }

    Agent_InitializePool();

    // Initialize the MQTT
    PW_LOG_INFO("Initialising MQTT connection");

    MQTTAgentMessageContext_t command_queue = {0};
    command_queue.queue = osMessageQueueNew(MQTT_AGENT_COMMAND_QUEUE_LENGTH, sizeof(MQTTAgentCommand_t *), NULL);
    MQTTAgentMessageInterface_t message_interface = {
        .pMsgCtx = &command_queue,
        .recv = Agent_MessageReceive,
        .send = Agent_MessageSend,
        .getCommand = Agent_GetCommand,
        .releaseCommand = Agent_ReleaseCommand,
    };
    SubscriptionElement_t global_subscription_list[SUBSCRIPTION_MANAGER_MAX_SUBSCRIPTIONS];
    MQTTStatus_t mqtt_status = MQTTAgent_Init(&mqtt_context,
                                              &message_interface,
                                              &mqttBuffer,
                                              &mqtt_transport,
                                              &get_time_ms,
                                              &on_incoming_mqtt_publish_message,
                                              global_subscription_list);
    if (mqtt_status != MQTTSuccess) {
        PW_LOG_ERROR("MQTT Init failed - MQTT status = %d", mqtt_status);
        return;
    }

    PW_LOG_INFO("Establishing TLS connection");
    AwsNetwork_init(&network_context, &networkConfig);
    int res = AwsNetwork_connect(&network_context, connect_mqtt);
    if (res) {
        PW_LOG_ERROR("Connection to network failed: %d", res);
        AwsNetwork_close(&network_context);
        return;
    }

    // Start MQTT processing thread. This thread receives data from the server
    // and trigger the mqttEventCallBack.
    // To stop it send and EVENT_TERMINATE_MQTT_THREAD to it.
    PW_LOG_INFO("Starting MQTT receiving thread");
    const osThreadAttr_t mqtt_recv_thread_config = {
        .stack_size = 4096 // Create the thread stack with a size of 4096 bytes
    };
    mqtt_recv_thread_events = osEventFlagsNew(NULL);
    if (mqtt_recv_thread_events == NULL) {
        PW_LOG_ERROR("Failed to create MQTT receiving thead events");
        return;
    }

    osThreadId_t mqtt_recv_thread = osThreadNew(mqtt_background_process, NULL, &mqtt_recv_thread_config);
    if (mqtt_recv_thread == NULL) {
        PW_LOG_ERROR("Failed to create MQTT processing thread");
        return;
    }

    // Publish a packet to the demo topic
    PW_LOG_INFO("Publishing demo message to %s", DEMO_PUBLISH_TOPIC);
    MQTTAgentCommandContext_t command_context = {
        .xTaskToNotify = osThreadGetId(),
        .pArgs = NULL,
        .xReturnStatus = MQTTSendFailed,
    };
    MQTTAgentCommandInfo_t command_info = {
        .cmdCompleteCallback = &on_command_complete,
        .pCmdCompleteCallbackContext = &command_context,
        .blockTimeMs = MQTT_AGENT_SEND_BLOCK_TIME_MS,
    };
    MQTTPublishInfo_t publish_info = {.qos = MQTTQoS1,
                                      .pTopicName = DEMO_PUBLISH_TOPIC,
                                      .topicNameLength = strlen(DEMO_PUBLISH_TOPIC),
                                      .pPayload = DEMO_PAYLOAD,
                                      .payloadLength = strlen(DEMO_PAYLOAD)};
    mqtt_status = MQTTAgent_Publish(&mqtt_context, &publish_info, &command_info);
    if (mqtt_status != MQTTSuccess) {
        PW_LOG_ERROR("MQTT Publish failed - MQTT status = %d", mqtt_status);
        return;
    }

    // Waiting for packet ack
    PW_LOG_INFO("Waiting for packet ack");
    uint32_t previous_event =
        osEventFlagsWait(mqtt_recv_thread_events, EVENT_MQTT_COMMAND_COMPLETE, osFlagsWaitAny, osWaitForever);
    if (previous_event & osFlagsError) {
        PW_LOG_ERROR("Faillure while waiting for ack");
        return;
    }
    osEventFlagsClear(mqtt_recv_thread_events, EVENT_MQTT_COMMAND_COMPLETE);

    PW_LOG_INFO("Disconnecting MQTT");
    mqtt_status = MQTTAgent_Disconnect(&mqtt_context, &command_info);
    if (mqtt_status != MQTTSuccess) {
        PW_LOG_ERROR("MQTT Disconnect failed - MQTT status = %d", mqtt_status);
        return;
    }

    previous_event =
        osEventFlagsWait(mqtt_recv_thread_events, EVENT_MQTT_COMMAND_COMPLETE, osFlagsWaitAny, osWaitForever);
    if (previous_event & osFlagsError) {
        PW_LOG_ERROR("Failure while waiting for MQTT Disconnect");
        return;
    }
    osEventFlagsClear(mqtt_recv_thread_events, EVENT_MQTT_COMMAND_COMPLETE);

    // Terminating MQTT receiving thread
    PW_LOG_INFO("Terminating MQTT receiving thread");
    uint32_t events = osEventFlagsSet(mqtt_recv_thread_events, EVENT_TERMINATE_MQTT_THREAD);
    if (events & osFlagsError) {
        PW_LOG_ERROR("Failed to set flag to terminate receive thread");
        return;
    }

    previous_event =
        osEventFlagsWait(mqtt_recv_thread_events, EVENT_MQTT_THREAD_TERMINATING, osFlagsWaitAny, osWaitForever);
    if (previous_event & osFlagsError) {
        PW_LOG_ERROR("Faillure while waiting for receiving thread termination signal");
        return;
    }
    if (osThreadTerminate(mqtt_recv_thread)) {
        PW_LOG_ERROR("Failed to terminate MQTT receiving thread");
        return;
    }

    // Cleanup of the connection
    PW_LOG_INFO("Closing TLS connection");

    AwsNetwork_close(&network_context);

    PW_LOG_INFO("Demo finished !");
}

/**
 * @brief Callback given to AwsNetwork_connect(): sends the MQTT CONNECT described by
 * mqtt_connect_info on the agent's MQTT context.
 *
 * @param ctx Network context supplied by the caller; unused.
 * @return 0 when MQTT_Connect() reports MQTTSuccess, -1 otherwise.
 */
static int connect_mqtt(NetworkContext_t *ctx)
{
    (void)ctx;

    PW_LOG_INFO("Setuping MQTT connection");

    // Required by MQTT_Connect(); the value is not used by this demo
    bool broker_has_session = false;
    const MQTTStatus_t connect_status =
        MQTT_Connect(&mqtt_context.mqttContext, &mqtt_connect_info, NULL, 0, &broker_has_session);
    if (connect_status == MQTTSuccess) {
        return 0;
    }

    PW_LOG_ERROR("MQTT Connection failed - MQTT status = %d", connect_status);
    return -1;
}

/**
 * @brief Command completion callback registered through MQTTAgentCommandInfo_t.
 *
 * Stores the command result in the caller's context and, on success only, raises
 * EVENT_MQTT_COMMAND_COMPLETE on mqtt_recv_thread_events.
 *
 * NOTE(review): nothing is signalled when the command fails, so a thread waiting
 * forever on EVENT_MQTT_COMMAND_COMPLETE is not woken in that case.
 *
 * @param context     Context supplied when the command was issued; receives the result.
 * @param return_info Result information for the completed command.
 */
static void on_command_complete(MQTTAgentCommandContext_t *context, MQTTAgentReturnInfo_t *return_info)
{
    const MQTTStatus_t command_result = return_info->returnCode;

    context->xReturnStatus = command_result;
    if (command_result != MQTTSuccess) {
        return;
    }

    osEventFlagsSet(mqtt_recv_thread_events, EVENT_MQTT_COMMAND_COMPLETE);
}

/**
 * @brief Callback invoked by the MQTT agent for each incoming PUBLISH.
 *
 * Forwards the message to the subscription manager, using the subscription list stored
 * in the agent context, and logs a warning when the subscription manager does not
 * handle it.
 *
 * @param context              Agent context; pIncomingCallbackContext holds the subscription list.
 * @param packet_id            Packet identifier of the PUBLISH; unused.
 * @param incoming_packet_info The received publish (topic and payload).
 */
static void on_incoming_mqtt_publish_message(struct MQTTAgentContext *context,
                                             uint16_t packet_id,
                                             MQTTPublishInfo_t *incoming_packet_info)
{
    (void)packet_id;

    if (!SubscriptionManager_HandleIncomingPublishes((SubscriptionElement_t *)context->pIncomingCallbackContext,
                                                     incoming_packet_info)) {
        // The topic name is delimited by topicNameLength and is not guaranteed to be
        // NUL-terminated, so the print is bounded explicitly.
        PW_LOG_WARN("Received an unsolicited publish from topic %.*s",
                    (int)incoming_packet_info->topicNameLength,
                    incoming_packet_info->pTopicName);
    }
}

// Thread function that process incoming MQTT packets
static void mqtt_background_process(void *arg)
{
    /* arg is not used */
    (void)arg;

    MQTTStatus_t status = MQTTAgent_CommandLoop(&mqtt_context);
    if (status != MQTTSuccess && status != MQTTNeedMoreBytes) {
        PW_LOG_ERROR("MQTTAgent_CommandLoop: %d", status);
    } else {
        PW_LOG_DEBUG("MQTTAgent_CommandLoop done");
    }

    // Continue processing until the thread is instructed to stop
    while (!(osEventFlagsGet(mqtt_recv_thread_events) & EVENT_TERMINATE_MQTT_THREAD)) {
    }

    // signal the thread is terminating and sleep
    osEventFlagsSet(mqtt_recv_thread_events, EVENT_MQTT_THREAD_TERMINATING);
    osDelay(osWaitForever);
}

/**
 * @brief Time source given to the MQTT agent: the RTOS tick count converted to
 * milliseconds, rounded up.
 *
 * The tick counter is sampled once and widened to 64 bits before being scaled, so the
 * multiplication cannot overflow and the quotient and remainder come from the same
 * sample.
 *
 * @return Elapsed milliseconds, truncated to 32 bits.
 */
static uint32_t get_time_ms(void)
{
    const uint64_t scaled_ticks = (uint64_t)osKernelGetTickCount() * 1000U;
    const uint32_t tick_freq = osKernelGetTickFreq();

    uint64_t time_ms = scaled_ticks / tick_freq;
    if ((scaled_ticks % tick_freq) != 0U) {
        // Round partial milliseconds up
        time_ms += 1U;
    }
    return (uint32_t)time_ms;
}

/**
 * @brief Fetch the example serial driver and configure it: initialise, power up,
 * asynchronous mode at 115200 baud, then enable TX and RX.
 *
 * Configuration stops silently at the first step that fails; `serial` is set in
 * every case.
 */
static void serial_setup()
{
    serial = get_example_serial();

    if (serial->Initialize(NULL) != ARM_DRIVER_OK) {
        return;
    }
    if (serial->PowerControl(ARM_POWER_FULL) != ARM_DRIVER_OK) {
        return;
    }
    if (serial->Control(ARM_USART_MODE_ASYNCHRONOUS, 115200) != ARM_DRIVER_OK) {
        return;
    }

    /* Some drivers have TX and RX enabled by default and lack an option to enable/disable
     * them, so "unsupported" is accepted as well as success. */
    int tx_result = serial->Control(ARM_USART_CONTROL_TX, 1);
    if (tx_result != ARM_DRIVER_OK && tx_result != ARM_DRIVER_ERROR_UNSUPPORTED) {
        return;
    }

    int rx_result = serial->Control(ARM_USART_CONTROL_RX, 1);
    if (rx_result != ARM_DRIVER_OK && rx_result != ARM_DRIVER_ERROR_UNSUPPORTED) {
        return;
    }
}

/**
 * @brief Example NV seed read callback registered with mbedtls_platform_set_nv_seed().
 *
 * Does not read any storage: every byte of the buffer is set to the fixed value 0xA5.
 *
 * @param buf     Destination buffer.
 * @param buf_len Number of bytes to fill.
 * @return 0 on success, -1 if @p buf is NULL.
 */
static int mbedtls_platform_example_nv_seed_read(unsigned char *buf, size_t buf_len)
{
    if (buf != NULL) {
        memset(buf, 0xA5, buf_len);
        return 0;
    }

    return (-1);
}

/**
 * @brief Example NV seed write callback registered with mbedtls_platform_set_nv_seed().
 *
 * Nothing is stored: both arguments are ignored.
 *
 * @param buf     Seed data; unused.
 * @param buf_len Length of the seed data; unused.
 * @return Always 0.
 */
static int mbedtls_platform_example_nv_seed_write(unsigned char *buf, size_t buf_len)
{
    (void)buf_len;
    (void)buf;

    return 0;
}

/** This callback is called by the ip network task. It translates from a network event code
 * to our app message queue code and sends the event to the main app task.
 *
 * @param event network up or down event.
 */
static void network_event_callback(const network_state_event_args_t *event_args)
{
    const net_msg_t msg = {.event = (event_args->event == NETWORK_UP) ? NET_EVENT_NETWORK_UP : NET_EVENT_NETWORK_DOWN};
    if (osMessageQueuePut(net_msg_queue, &msg, 0, 0) != osOK) {
        PW_LOG_WARN("Failed to send message to net_msg_queue");
    }
}
